package org.neuroph.nnet.learning;

import com.google.firebase.remoteconfig.FirebaseRemoteConfig;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;

/* loaded from: classes2.dex */
/**
 * Resilient backpropagation (RPROP) learning rule.
 *
 * <p>Instead of scaling weight changes by a global learning rate, every weight keeps its own step
 * size that is adapted from the sign of consecutive accumulated gradients:
 * <ul>
 *   <li>the gradient kept its sign: the step grows by {@code increaseFactor}, capped at
 *       {@code maxDelta};</li>
 *   <li>the gradient flipped its sign: the step shrinks by {@code decreaseFactor}, floored at
 *       {@code minDelta}, and the previous weight change is reverted;</li>
 *   <li>no sign information (either gradient is zero): the step size is kept.</li>
 * </ul>
 * Only the sign of the accumulated gradient decides the direction of a weight change; its
 * magnitude is never used.
 *
 * <p>The constructor forces batch mode: gradients are accumulated in
 * {@link #updateNeuronWeights(Neuron)} and applied in {@link #doBatchWeightsUpdate()}.
 */
public class ResilientPropagation extends BackPropagation {
    /** Values whose magnitude is below this are treated as having sign 0. */
    private static final double ZERO_TOLERANCE = 1.0E-27d;
    private double decreaseFactor = 0.5d;
    private double increaseFactor = 1.2d;
    private double initialDelta = 0.1d;
    private double maxDelta = 1.0d;
    private double minDelta = 1.0E-6d;

    /**
     * Per-weight RPROP state, attached to every input-connection weight in {@link #onStart()}.
     *
     * <p>The (misspelled) class name is kept for compatibility with existing callers.
     */
    public class ResilientWeightTrainingtData {
        /** Gradient accumulated since the last weight update; reset to 0 after each update. */
        public double gradient;
        /** Current step size of this weight. */
        public double previousDelta;
        /** Gradient seen at the previous update, used to detect a sign change. */
        public double previousGradient;
        /** Weight change applied at the previous update, reverted when the gradient flips. */
        public double previousWeightChange;

        public ResilientWeightTrainingtData() {
            // Read at construction time: setInitialDelta() only affects state created afterwards
            // (i.e. on the next onStart()).
            this.previousDelta = ResilientPropagation.this.initialDelta;
        }
    }

    /** Creates an RPROP rule; RPROP always runs in batch mode. */
    public ResilientPropagation() {
        super.setBatchMode(true);
    }

    /**
     * Returns the sign of {@code d} as -1, 0 or 1, treating values within
     * {@link #ZERO_TOLERANCE} of zero as 0.
     */
    private int sign(double d) {
        if (Math.abs(d) < ZERO_TOLERANCE) {
            return 0;
        }
        return d > 0.0d ? 1 : -1;
    }

    /**
     * Applies the RPROP update to every input-connection weight of every layer except layer 0.
     * NOTE(review): layer 0 is presumably the input layer, which has no trainable input
     * connections — confirm.
     */
    @Override
    protected void doBatchWeightsUpdate() {
        Layer[] layers = this.neuralNetwork.getLayers();
        for (int layerIndex = this.neuralNetwork.getLayersCount() - 1; layerIndex > 0; layerIndex--) {
            for (Neuron neuron : layers[layerIndex].getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    resillientWeightUpdate(connection.getWeight());
                }
            }
        }
    }

    /** @return factor (default 0.5) by which the step size shrinks when the gradient flips sign */
    public double getDecreaseFactor() {
        return this.decreaseFactor;
    }

    /** @return factor (default 1.2) by which the step size grows while the gradient keeps its sign */
    public double getIncreaseFactor() {
        return this.increaseFactor;
    }

    /** @return step size (default 0.1) assigned to every weight when training starts */
    public double getInitialDelta() {
        return this.initialDelta;
    }

    /** @return upper bound (default 1.0) for a weight's step size */
    public double getMaxDelta() {
        return this.maxDelta;
    }

    /** @return lower bound (default 1e-6) for a weight's step size */
    public double getMinDelta() {
        return this.minDelta;
    }

    /**
     * Runs the superclass start-up, then attaches fresh RPROP state to every input-connection
     * weight of every layer. Any state from a previous training run is discarded.
     */
    @Override
    public void onStart() {
        super.onStart();
        for (Layer layer : this.neuralNetwork.getLayers()) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    connection.getWeight().setTrainingData(new ResilientWeightTrainingtData());
                }
            }
        }
    }

    /**
     * Applies one RPROP step to {@code weight} using its accumulated gradient, then resets the
     * accumulator.
     *
     * @param weight a weight whose training data is a {@link ResilientWeightTrainingtData}
     *               (installed by {@link #onStart()}); otherwise a ClassCastException is thrown
     */
    protected void resillientWeightUpdate(Weight weight) {
        ResilientWeightTrainingtData data = (ResilientWeightTrainingtData) weight.getTrainingData();

        // > 0: same direction as last time; < 0: sign flipped; 0: one of the gradients is zero.
        int gradientSignChange = sign(data.previousGradient * data.gradient);

        double weightChange;
        if (gradientSignChange > 0) {
            // Consistent direction: enlarge the step, but never beyond maxDelta.
            double delta = Math.min(data.previousDelta * this.increaseFactor, this.maxDelta);
            weightChange = sign(data.gradient) * delta;
            data.previousDelta = delta;
        } else if (gradientSignChange < 0) {
            // Sign flipped, so the last step overshot: shrink the step (not below minDelta)
            // and undo the previous weight change.
            data.previousDelta = Math.max(data.previousDelta * this.decreaseFactor, this.minDelta);
            weightChange = -data.previousWeightChange;
            // Clearing both gradients makes the next update see a sign product of 0, so the
            // step is not shrunk twice in a row.
            data.gradient = 0.0d;
            data.previousGradient = 0.0d;
        } else {
            // No sign information (first update, or right after a flip): keep the step size.
            weightChange = sign(data.gradient) * data.previousDelta;
        }

        weight.value += weightChange;
        data.previousWeightChange = weightChange;
        data.previousGradient = data.gradient;
        data.gradient = 0.0d;
    }

    /** Sets the factor by which the step size shrinks when the gradient flips sign. */
    public void setDecreaseFactor(double d) {
        this.decreaseFactor = d;
    }

    /** Sets the factor by which the step size grows while the gradient keeps its sign. */
    public void setIncreaseFactor(double d) {
        this.increaseFactor = d;
    }

    /**
     * Sets the initial step size. It takes effect for state created by the next
     * {@link #onStart()}, not for weights already being trained.
     */
    public void setInitialDelta(double d) {
        this.initialDelta = d;
    }

    /** Sets the upper bound for a weight's step size. */
    public void setMaxDelta(double d) {
        this.maxDelta = d;
    }

    /** Sets the lower bound for a weight's step size. */
    public void setMinDelta(double d) {
        this.minDelta = d;
    }

    /**
     * Accumulates {@code neuron.getError() * input} into the gradient of each of the neuron's
     * input-connection weights. No weight is changed here; that happens in
     * {@link #doBatchWeightsUpdate()}.
     */
    @Override
    public void updateNeuronWeights(Neuron neuron) {
        for (Connection connection : neuron.getInputConnections()) {
            double input = connection.getInput();
            // A zero input contributes nothing to the gradient, so it is skipped.
            if (input != 0.0d) {
                ((ResilientWeightTrainingtData) connection.getWeight().getTrainingData()).gradient += neuron.getError() * input;
            }
        }
    }
}
